{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "# We don't use the whole dataset for efficiency purpose, but feel free to increase these numbers\n",
    "n_train_items = 640\n",
    "n_test_items = 640"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part X - 对MNIST进行安全的训练和评估\n",
    "\n",
    "在构建机器学习即服务解决方案（MLaaS）时，公司可能需要请求其他合作伙伴访问数据以训练其模型。在卫生或金融领域，模型和数据都非常关键：模型参数是业务资产，而数据是严格监管的个人数据。\n",
    "\n",
    "在这种情况下，一种可能的解决方案是对模型和数据都进行加密，并在加密后的值上训练机器学习模型。例如，这保证了公司不会访问患者的病历，并且医疗机构将无法观察他们贡献的模型。存在几种允许对加密数据进行计算的加密方案，其中包括安全多方计算（SMPC），同态加密（FHE / SHE）和功能加密（FE）。我们将在这里集中讨论多方计算（已在教程5中进行了介绍），它由私有加性共享组成，并依赖于加密协议SecureNN和SPDZ。\n",
    "\n",
    "本教程的确切设置如下：考虑您是服务器，并且您想对$n$个工作机持有的某些数据进行模型训练。服务器机密共享他的模型，并将每个共享发送给工作机。工作机秘密共享他们的数据并在他们之间交换。在我们将要研究的配置中，有2个工作机：alice和bob。交换共享后，他们每个人现在拥有自己的共享，另一工作机的数据共享和模型共享。现在，计算可以开始使用适当的加密协议来私下训练模型。训练模型后，所有共享都可以发送回服务器以对其进行解密。下图对此进行了说明："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![SMPC Illustration](https://github.com/OpenMined/PySyft/raw/11c85a121a1a136e354945686622ab3731246084/examples/tutorials/material/smpc_illustration.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了举例说明这个过程，让我们假设alice和bob都拥有MNIST数据集的一部分，然后训练一个模型来执行数字分类！\n",
    "\n",
    "作者:\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)\n",
    "\n",
    "中文版译者：\n",
    "- Hou Wei - github：[@dljgs1](https://github.com/dljgs1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 在MNIST上进行加密训练的demo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导包以及训练配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此类描述了训练的所有超参数。 请注意，它们在这里都是公开的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 64\n",
    "        self.epochs = epochs\n",
    "        self.lr = 0.02\n",
    "        self.seed = 1\n",
    "        self.log_interval = 1 # Log info at each batch\n",
    "        self.precision_fractional = 3\n",
    "\n",
    "args = Arguments()\n",
    "\n",
    "_ = torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这是PySyft的进口商品。 我们连接到两个名为`alice`和`bob`的远程工作机，并请求另一个名为`crypto_provider`的工作机，它提供了我们可能需要的所有加密原语。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy  # import the Pysyft library\n",
    "hook = sy.TorchHook(torch)  # hook PyTorch to add extra functionalities like Federated and Encrypted Learning\n",
    "\n",
    "# simulation functions\n",
    "def connect_to_workers(n_workers):\n",
    "    return [\n",
    "        sy.VirtualWorker(hook, id=f\"worker{i+1}\")\n",
    "        for i in range(n_workers)\n",
    "    ]\n",
    "def connect_to_crypto_provider():\n",
    "    return sy.VirtualWorker(hook, id=\"crypto_provider\")\n",
    "\n",
    "workers = connect_to_workers(n_workers=2)\n",
    "crypto_provider = connect_to_crypto_provider()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 获取访问权限和秘密共享数据\n",
    "\n",
    "在这里，我们使用一个效用函数来模拟以下行为：我们假设MNIST数据集分布在各个部分中，每个部分都由我们的一个工作机持有。 然后，工作机将其数据分批拆分，并在彼此之间秘密共享其数据。 返回的最终对象是这些秘密共享批次上的可迭代对象，我们称之为“私有数据加载器”。 请注意，在此过程中，本地工作人员（因此我们）从未访问过数据。\n",
    "\n",
    "我们像往常一样获得了训练和测试私有数据集，并且输入和标签都是秘密共享的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_private_data_loaders(precision_fractional, workers, crypto_provider):\n",
    "    \n",
    "    def one_hot_of(index_tensor):\n",
    "        \"\"\"\n",
    "        Transform to one hot tensor\n",
    "        \n",
    "        Example:\n",
    "            [0, 3, 9]\n",
    "            =>\n",
    "            [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
    "             [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
    "             [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]\n",
    "            \n",
    "        \"\"\"\n",
    "        onehot_tensor = torch.zeros(*index_tensor.shape, 10) # 10 classes for MNIST\n",
    "        onehot_tensor = onehot_tensor.scatter(1, index_tensor.view(-1, 1), 1)\n",
    "        return onehot_tensor\n",
    "        \n",
    "    def secret_share(tensor):\n",
    "        \"\"\"\n",
    "        Transform to fixed precision and secret share a tensor\n",
    "        \"\"\"\n",
    "        return (\n",
    "            tensor\n",
    "            .fix_precision(precision_fractional=precision_fractional)\n",
    "            .share(*workers, crypto_provider=crypto_provider, requires_grad=True)\n",
    "        )\n",
    "    \n",
    "    transformation = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "    \n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=True, download=True, transform=transformation),\n",
    "        batch_size=args.batch_size\n",
    "    )\n",
    "    \n",
    "    private_train_loader = [\n",
    "        (secret_share(data), secret_share(one_hot_of(target)))\n",
    "        for i, (data, target) in enumerate(train_loader)\n",
    "        if i < n_train_items / args.batch_size\n",
    "    ]\n",
    "    \n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=False, download=True, transform=transformation),\n",
    "        batch_size=args.test_batch_size\n",
    "    )\n",
    "    \n",
    "    private_test_loader = [\n",
    "        (secret_share(data), secret_share(target.float()))\n",
    "        for i, (data, target) in enumerate(test_loader)\n",
    "        if i < n_test_items / args.test_batch_size\n",
    "    ]\n",
    "    \n",
    "    return private_train_loader, private_test_loader\n",
    "    \n",
    "    \n",
    "private_train_loader, private_test_loader = get_private_data_loaders(\n",
    "    precision_fractional=args.precision_fractional,\n",
    "    workers=workers,\n",
    "    crypto_provider=crypto_provider\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型规格\n",
    "\n",
    "这是我们将使用的模型，它是一个相当简单的模型，但是[已证明在MNIST上表现良好](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(28 * 28, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        self.fc3 = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 28 * 28)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练和测试功能\n",
    "\n",
    "训练几乎像往常一样进行，真正的区别是我们不能使用像负对数可能性（PyTorch中的F.nll_loss）之类的损失，因为使用SMPC再现这些功能相当复杂。相反，我们使用更简单的均方误差损失。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(args, model, private_train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(private_train_loader): # <-- now it is a private dataset\n",
    "        start_time = time.time()\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        output = model(data)\n",
    "        \n",
    "        # loss = F.nll_loss(output, target)  <-- not possible here\n",
    "        batch_size = output.shape[0]\n",
    "        loss = ((output - target)**2).sum().refresh()/batch_size\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get().float_precision()\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tTime: {:.3f}s'.format(\n",
    "                epoch, batch_idx * args.batch_size, len(private_train_loader) * args.batch_size,\n",
    "                100. * batch_idx / len(private_train_loader), loss.item(), time.time() - start_time))\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试功能不变！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, private_test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in private_test_loader:\n",
    "            start_time = time.time()\n",
    "            \n",
    "            output = model(data)\n",
    "            pred = output.argmax(dim=1)\n",
    "            correct += pred.eq(target.view_as(pred)).sum()\n",
    "\n",
    "    correct = correct.get().float_precision()\n",
    "    print('\\nTest set: Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        correct.item(), len(private_test_loader)* args.test_batch_size,\n",
    "        100. * correct.item() / (len(private_test_loader) * args.test_batch_size)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 让我们开始训练吧！\n",
    "\n",
    "关于这里发生的事情的一些注意事项。首先，我们秘密共享所有工作机的所有模型参数。其次，我们将优化程序的超参数转换为固定精度。注意，我们不需要秘密共享它们，因为它们在我们的上下文中是公共的，但是当秘密共享值存在于有限域中时，我们仍然需要使用`.fix_precision`将它们移入有限域中，以便执行一致的操作像权重更新: $W \\leftarrow W - \\alpha * \\Delta W$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()\n",
    "model = model.fix_precision().share(*workers, crypto_provider=crypto_provider, requires_grad=True)\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr)\n",
    "optimizer = optimizer.fix_precision() \n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, private_train_loader, optimizer, epoch)\n",
    "    test(args, model, private_test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这！使用MNIST数据集的一小部分，使用100％加密的训练，您只能获得75％的准确性！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 讨论区\n",
    "\n",
    "通过分析我们刚刚做的事情，让我们更仔细地了解加密训练的功能。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 计算时间\n",
    "\n",
    "第一件事显然是运行时间！您肯定已经注意到，它比明文训练要慢得多。特别是，在1批64项上进行一次迭代需要3.2s，而在纯PyTorch中只有13 ms。尽管这似乎是一个阻塞程序，但请回想一下，这里的所有事情都是远程发生的，并且是在加密的世界中发生的：没有单个数据项被公开。更具体地说，处理一项的时间是50ms，这还不错。真正的问题是分析何时需要加密训练以及何时仅加密预测就足够了。例如，在生产就绪的情况下，完全可以接受50毫秒执行预测！\n",
    "\n",
    "一个主要的瓶颈是昂贵的激活功能的使用：SMPC的relu激活非常昂贵，因为它使用私有比较和SecureNN协议。例如，如果我们用二次激活代替relu，就像在CryptoNets等加密计算的几篇论文中所做的那样，我们将从3.2s降到1.2s。\n",
    "\n",
    "通常，关键思想是仅加密必需的内容，本教程向您展示它可以多么简单。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 SMPC的反向传播\n",
    "\n",
    "您可能会想知道，尽管我们在有限域中使用整数，但我们如何执行反向传播和梯度更新。为此，我们开发了一个新的syft张量，称为AutogradTensor。 尽管您可能没有看过本教程，但它还是大量使用它！ 让我们通过打印模型的重量进行检查："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fc3.bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "和一个数据项："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_batch, input_data = 0, 0\n",
    "private_train_loader[first_batch][input_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "正如您所看到的，AutogradTensor在那里！ 它位于Torch包装器和FixedPrecisionTensor之间，这表明值现在位于有限域中。此AutogradTensor的目标是在对加密值进行操作时存储计算图。这很有用，因为在向后调用反向传播时，此AutogradTensor会覆盖所有与加密计算不兼容的向后函数，并指示如何计算这些梯度。 例如，对于使用Beaver三元组技巧完成的乘法，我们不想再对技巧进行区分，因为区分乘法应该非常容易：$\\partial_b(a\\cdot b)= \\cdot \\partial b$。 例如，这是我们描述如何计算这些梯度的方法：\n",
    "\n",
    "```python\n",
    "class MulBackward(GradFunc):\n",
    "    def __init__(self, self_, other):\n",
    "        super().__init__(self, self_, other)\n",
    "        self.self_ = self_\n",
    "        self.other = other\n",
    "\n",
    "    def gradient(self, grad):\n",
    "        grad_self_ = grad * self.other\n",
    "        grad_other = grad * self.self_ if type(self.self_) == type(self.other) else None\n",
    "        return (grad_self_, grad_other)\n",
    "```\n",
    "\n",
    "如果您想知道更多我们如何实现梯度的，可以查看`tensors / interpreters / gradients.py`。\n",
    "\n",
    "就计算图而言，这意味着该图的副本保留在本地，并且协调正向传递的服务器还提供有关如何进行反向传递的指令。 在我们的环境中，这是一个完全正确的假设。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.3 安全保障\n",
    "\n",
    "最后，让我们给出一些有关我们在此处实现的安全性的提示：我们正在考虑的对手**诚实但好奇**：这意味着对手无法通过运行此协议来了解有关数据的任何信息，但是恶意的对手仍可能偏离协议，例如尝试破坏共享以破坏计算。在这样的SMPC计算（包括私有比较）中针对恶意对手的安全性仍然是一个未解决的问题。\n",
    "\n",
    "此外，即使安全多方计算确保不访问训练数据，此处仍然存在来自纯文本世界的许多威胁。例如，当您可以向模型提出请求时（在MLaaS的上下文中），您可以获得可能泄露有关训练数据集信息的预测。特别是，您没有针对成员资格攻击的任何保护措施，这是对机器学习服务的常见攻击，在这种攻击中，对手想确定数据集中是否使用了特定项目。除此之外，其他攻击，例如意外的记忆过程（学习有关数据项特定特征的模型），模型反演或提取仍然可能。\n",
    "\n",
    "对上述许多威胁有效的一种通用解决方案是添加差异隐私。它可以与安全的多方计算完美地结合在一起，并且可以提供非常有趣的安全性保证。我们目前正在研究几种实现方式，并希望提出一个将两者结合起来的示例！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 恭喜!!! 是时候加入社区了!\n",
    "\n",
    "祝贺您完成本笔记本教程！ 如果您喜欢此方法，并希望加入保护隐私、去中心化AI和AI供应链（数据）所有权的运动，则可以通过以下方式做到这一点！\n",
    "\n",
    "### 给 PySyft 加星\n",
    "\n",
    "帮助我们的社区的最简单方法是仅通过给GitHub存储库加注星标！ 这有助于提高人们对我们正在构建的出色工具的认识。\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### 选择我们的教程\n",
    "\n",
    "我们编写了非常不错的教程，以更好地了解联合学习和隐私保护学习的外观，以及我们如何为实现这一目标添砖加瓦。\n",
    "\n",
    "- [Checkout the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### 加入我们的 Slack!\n",
    "\n",
    "保持最新进展的最佳方法是加入我们的社区！ 您可以通过填写以下表格来做到这一点[http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### 加入代码项目!\n",
    "\n",
    "对我们的社区做出贡献的最好方法是成为代码贡献者！ 您随时可以转到PySyft GitHub的Issue页面并过滤“projects”。这将向您显示所有概述，选择您可以加入的项目！如果您不想加入项目，但是想做一些编码，则还可以通过搜索标记为“good first issue”的GitHub问题来寻找更多的“一次性”微型项目。\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### 捐赠\n",
    "\n",
    "如果您没有时间为我们的代码库做贡献，但仍想提供支持，那么您也可以成为Open Collective的支持者。所有捐款都将用于我们的网络托管和其他社区支出，例如黑客马拉松和聚会！\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
